{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Problem: Implement a CNN for CIFAR-10\n",
    "\n",
    "### Problem Statement\n",
    "Implement a **Convolutional Neural Network (CNN)** for image classification on the **CIFAR-10** dataset using PyTorch. The model should combine convolutional layers for feature extraction, pooling layers for downsampling, and fully connected layers for classification. Your task is to complete the CNN model by defining these layers and implementing the forward pass.\n",
    "\n",
    "### Requirements\n",
    "1. **Define the CNN Model**:\n",
    "   - Add **convolutional layers** for feature extraction.\n",
    "   - Add **pooling layers** to reduce the spatial dimensions.\n",
    "   - Add **fully connected layers** to output class predictions.\n",
    "   - The model should be capable of processing CIFAR-10 input images of size `32x32x3`, which PyTorch loads as tensors of shape `(3, 32, 32)` (channels first).\n",
    "\n",
    "### Constraints\n",
    "- The CNN should be designed with multiple convolutional and pooling layers followed by fully connected layers.\n",
    "- Ensure the model is compatible with the CIFAR-10 dataset, which contains 10 classes.\n",
    "\n",
    "\n",
    "<details>\n",
    "  <summary>💡 Hint</summary>\n",
    "  Add the convolutional (conv1, conv2), pooling (pool), and fully connected layers (fc1, fc2) in CNNModel.__init__.\n",
    "  <br>\n",
    "  Implement the forward pass to process inputs through these layers.\n",
    "</details>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\\cifar-10-python.tar.gz\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ./data\\cifar-10-python.tar.gz to ./data\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Load CIFAR-10 dataset\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "])\n",
    "\n",
    "train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
    "\n",
    "test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)"
   ]
  },
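  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Normalize` transform above uses a mean and standard deviation of 0.5 for every channel. Since `ToTensor` scales pixel values to `[0, 1]`, the normalized values lie in `[-1, 1]`. The cell below is a minimal sketch of that arithmetic in plain Python; `normalize` is a hypothetical helper written for illustration, not a torchvision function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize(value, mean=0.5, std=0.5):\n",
    "    # Hypothetical helper: same per-channel arithmetic as transforms.Normalize\n",
    "    return (value - mean) / std\n",
    "\n",
    "# ToTensor yields values in [0, 1]; mean=0.5 and std=0.5 map that range to [-1, 1]\n",
    "for pixel in (0.0, 0.5, 1.0):\n",
    "    print(pixel, '->', normalize(pixel))"
   ]
  },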
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the CNN Model\n",
    "class CNNModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)  # 3x32x32 -> 32x32x32 (CxHxW)\n",
    "        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # 32x32x32 -> 64x32x32\n",
    "        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # 64x32x32 -> 64x16x16\n",
    "        self.fc1 = nn.Linear(64 * 16 * 16, 128)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.relu(self.conv1(x))\n",
    "        x = self.pool(self.relu(self.conv2(x)))\n",
    "        x = x.view(x.size(0), -1)  # Flatten\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
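  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape comments in `CNNModel` follow from the standard output-size formula for convolution and pooling layers, `floor((size + 2 * padding - kernel) / stride) + 1`. The cell below is a minimal sketch of that arithmetic in plain Python; `conv_out_size` is a hypothetical helper introduced here for illustration, not part of PyTorch. It shows where the `64 * 16 * 16` input size of `fc1` comes from."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_out_size(size, kernel, stride=1, padding=0):\n",
    "    # Hypothetical helper: spatial output size of a conv or pooling layer\n",
    "    return (size + 2 * padding - kernel) // stride + 1\n",
    "\n",
    "side = 32  # CIFAR-10 images are 32x32\n",
    "side = conv_out_size(side, kernel=3, stride=1, padding=1)  # conv1 keeps 32\n",
    "side = conv_out_size(side, kernel=3, stride=1, padding=1)  # conv2 keeps 32\n",
    "side = conv_out_size(side, kernel=2, stride=2)  # pool halves to 16\n",
    "flat_features = 64 * side * side  # 64 channels after conv2\n",
    "print(side, flat_features)  # 16 16384"
   ]
  },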
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/10], Loss: 1.4057\n",
      "Epoch [2/10], Loss: 0.9852\n",
      "Epoch [3/10], Loss: 0.2743\n",
      "Epoch [4/10], Loss: 0.9068\n",
      "Epoch [5/10], Loss: 0.2459\n",
      "Epoch [6/10], Loss: 0.4891\n",
      "Epoch [7/10], Loss: 0.0719\n",
      "Epoch [8/10], Loss: 0.1010\n",
      "Epoch [9/10], Loss: 0.0075\n",
      "Epoch [10/10], Loss: 0.1189\n"
     ]
    }
   ],
   "source": [
    "# Initialize the model, loss function, and optimizer\n",
    "model = CNNModel()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Training loop\n",
    "epochs = 10\n",
    "for epoch in range(epochs):\n",
    "    for images, labels in train_loader:\n",
    "        # Forward pass\n",
    "        outputs = model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        # Backward pass and optimization\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # Note: this is the loss of the last mini-batch only, so it fluctuates between epochs\n",
    "    print(f\"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}\")"
   ]
  },
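  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training loop above prints the loss of the last mini-batch in each epoch, which is why the reported values jump around. One possible alternative, sketched below under assumptions, is to report a sample-weighted average over the whole epoch. `RunningMean` is a hypothetical helper, not a PyTorch API, and the per-batch losses are made-up numbers for illustration. Inside the loop it could be fed with `meter.update(loss.item(), labels.size(0))`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RunningMean:\n",
    "    # Hypothetical helper: sample-weighted running mean of per-batch values\n",
    "    def __init__(self):\n",
    "        self.total = 0.0\n",
    "        self.count = 0\n",
    "\n",
    "    def update(self, value, n=1):\n",
    "        # value is a per-batch mean, n is the number of samples in that batch\n",
    "        self.total += value * n\n",
    "        self.count += n\n",
    "\n",
    "    @property\n",
    "    def mean(self):\n",
    "        return self.total / self.count if self.count else 0.0\n",
    "\n",
    "# Made-up (loss, batch size) pairs, purely for illustration\n",
    "meter = RunningMean()\n",
    "for batch_loss, batch_size in [(1.0, 64), (0.5, 64), (0.25, 32)]:\n",
    "    meter.update(batch_loss, batch_size)\n",
    "print(f\"Epoch loss: {meter.mean:.4f}\")  # Epoch loss: 0.6500"
   ]
  },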
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy: 67.59%\n"
     ]
    }
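   ],
   "source": [
    "# Evaluate on the test set\n",
    "correct = 0\n",
    "total = 0\n",
    "model.eval()  # inference mode; a no-op for this model (no dropout or batch norm), but good practice\n",
    "with torch.no_grad():\n",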
    "    for images, labels in test_loader:\n",
    "        outputs = model(images)\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "print(f\"Test Accuracy: {100 * correct / total:.2f}%\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
